-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BYOC][TRT] Support batch norm for all ranks <=5, and all axes #7026
Conversation
ae0c87b
to
27982b8
Compare
All tests in test_tensorrt.py passed locally |
c03ece8
to
f51f808
Compare
auto input_dims = TrtDimsToVector(input->getDimensions()); | ||
const size_t min_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; | ||
const size_t max_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 4 : 5; | ||
ICHECK_LE(input_dims.size(), max_rank); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you convert these checks to use Diagnostic
instead of generating an assertion, we should strive to replace most of these with end-user readable errors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @jroesch, thanks for reviewing!
These checks are more for sanity checking, since the annotation functions in python/tvm/relay/op/contrib/tensorrt.py will filter out the unsupported ops before they ever get to this code. I don't expect users to ever see these.
Anyway, I can make a separate PR to port all of the ICHECK to Diagnostics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Let's have a separate PR to migrate all errors to diagnostic.
Thanks @trevor-m @anijain2305 @jroesch |
…e#7026) * [TRT] Support batch norm for all ranks <=5, and all axis * Add return false * Fix TRT < 6 build
…e#7026) * [TRT] Support batch norm for all ranks <=5, and all axis * Add return false * Fix TRT < 6 build
…e#7026) * [TRT] Support batch norm for all ranks <=5, and all axis * Add return false * Fix TRT < 6 build
Previous batch norm only supported rank 4 inputs with axis 1 or 3. Now we support input ranks and axes 1-5.